// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/pass/convert_power_to_powerstatic.hpp"

#include <memory>

#include "openvino/core/graph_util.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/power.hpp"
#include "openvino/pass/matcher_pass.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "snippets/itt.hpp"
#include "snippets/op/powerstatic.hpp"
#include "snippets/op/scalar.hpp"

// Registers a matcher that rewrites ov::op::v1::Power nodes whose second input
// (index 1) is a snippets::op::Scalar into a snippets::op::PowerStatic node that
// carries the exponent as a float value taken from that Scalar.
ov::snippets::pass::ConvertPowerToPowerStatic::ConvertPowerToPowerStatic() {
    MATCHER_SCOPE(ConvertPowerToPowerStatic);

    // Accept only Power nodes; the exponent input is inspected only after the
    // node type has been confirmed.
    const auto is_power_with_scalar_exponent = [](const std::shared_ptr<Node>& node) {
        if (!is_type<ov::op::v1::Power>(node)) {
            return false;
        }
        return is_type<snippets::op::Scalar>(node->get_input_node_shared_ptr(1));
    };
    auto pattern_root =
        std::make_shared<ov::pass::pattern::op::Label>(ov::pass::pattern::any_input(), is_power_with_scalar_exponent);

    ov::graph_rewrite_callback convert = [](ov::pass::pattern::Matcher& matcher) {
        OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::ConvertPowerToPowerStatic")
        // The pattern predicate has already checked both node types, so the casts
        // below are used without further validation.
        const auto power_node = ov::as_type_ptr<ov::op::v1::Power>(matcher.get_match_root());
        const auto exponent_node = ov::as_type_ptr<snippets::op::Scalar>(power_node->get_input_node_shared_ptr(1));

        // Only the first element of the Scalar is read as the exponent.
        const auto exponent = exponent_node->cast_vector<float>()[0];
        const auto base = power_node->input(0).get_source_output();

        auto replacement = std::make_shared<snippets::op::PowerStatic>(base, exponent);
        // Keep the original node's friendly name and runtime info on the new node.
        replacement->set_friendly_name(power_node->get_friendly_name());
        copy_runtime_info(power_node, replacement);
        ov::replace_node(power_node, replacement);

        // Report that the graph was modified.
        return true;
    };

    auto pattern_matcher = std::make_shared<ov::pass::pattern::Matcher>(pattern_root, matcher_name);
    register_matcher(pattern_matcher, convert);
}
